import os
import sqlalchemy
import functools
import warnings

from sqlalchemy import event, inspect, orm
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.orm.session import Session as SessionBase
from threading import Lock

from .model import Model, DefaultMeta, _SQLAlchemyState


class _FakeSignal(object):
    def __init__(self, name, doc=None):
        self.name = name
        self.__doc__ = doc

    def send(self, *args, **kwargs):
        pass

    def _fail(self, *args, **kwargs):
        raise RuntimeError(
            "Signalling support is unavailable because the blinker"
            " library is not installed."
        )

    connect = connect_via = connected_to = temporarily_connected_to = _fail
    disconnect = _fail
    has_receivers_for = receivers_for = _fail
    del _fail

class Namespace(object):
    """Factory for :class:`_FakeSignal` objects."""

    def signal(self, name, doc=None):
        """Return a new :class:`_FakeSignal` with the given name and doc."""
        fake = _FakeSignal(name, doc)
        return fake


# ``Namespace.signal`` returns ``_FakeSignal`` objects, so both signals
# below accept ``send()`` as a no-op and raise RuntimeError on any attempt
# to connect a receiver.
_signals = Namespace()
# Sent by ``_SessionSignalEvents.after_commit`` with the recorded changes.
models_committed = _signals.signal('models-committed')
# Sent by ``_SessionSignalEvents.before_commit`` with the recorded changes.
before_models_committed = _signals.signal('before-models-committed')

def engine_config_warning(config, version, deprecated_config_key, engine_option):
    """Warn when a deprecated engine-related config key is set.

    :param config: mapping that contains ``deprecated_config_key``.
    :param version: version string named in the message as the removal
        version.
    :param deprecated_config_key: config key to look up; nothing happens
        when its value is ``None``.
    :param engine_option: name of the ``SQLALCHEMY_ENGINE_OPTIONS`` entry
        recommended as the replacement.
    """
    if config[deprecated_config_key] is None:
        return

    template = (
        'The `{}` config option is deprecated and will be removed in'
        ' v{}.  Use `SQLALCHEMY_ENGINE_OPTIONS[\'{}\']` instead.'
    )
    warnings.warn(
        template.format(deprecated_config_key, version, engine_option),
        DeprecationWarning
    )

def _set_default_query_class(d, cls):
    if 'query_class' not in d:
        d['query_class'] = cls

def _include_sqlalchemy(obj, cls):
    """Expose the public SQLAlchemy API as attributes of ``obj``.

    Every name in ``sqlalchemy.__all__`` and ``sqlalchemy.orm.__all__`` that
    ``obj`` does not already have is copied onto it.  ``Table`` and the
    relationship helpers are then replaced by wrapped versions, and the
    ``event`` module is attached as ``obj.event``.

    :param obj: object that receives the attributes.
    :param cls: query class handed to the relationship wrappers.
    """
    for module in (sqlalchemy, sqlalchemy.orm):
        for key in module.__all__:
            if hasattr(obj, key):
                # Never overwrite something obj already provides.
                continue
            setattr(obj, key, getattr(module, key))

    # Note: obj.Table does not attempt to be a SQLAlchemy Table class.
    obj.Table = _make_table(obj)

    for helper_name in ('relationship', 'relation', 'dynamic_loader'):
        wrapped = _wrap_with_default_query_class(getattr(obj, helper_name), cls)
        setattr(obj, helper_name, wrapped)

    obj.event = event

def _make_table(db):
    """Build the ``Table`` factory bound to ``db``.

    The returned callable forwards to ``sqlalchemy.Table`` after two
    adjustments: ``db.metadata`` is inserted as the second positional
    argument when a ``db.Column`` is found in that position, and the
    ``info`` keyword always carries a ``'bind_key'`` entry (``None`` when
    the caller did not set one).
    """
    def _make_table(*args, **kwargs):
        metadata_omitted = len(args) > 1 and isinstance(args[1], db.Column)
        if metadata_omitted:
            name, columns = args[0], args[1:]
            args = (name, db.metadata) + columns

        # A missing or falsy ``info`` is replaced by a fresh dict; a dict
        # supplied by the caller is reused and updated in place.
        table_info = kwargs.pop('info', None) or {}
        table_info.setdefault('bind_key', None)
        kwargs['info'] = table_info

        return sqlalchemy.Table(*args, **kwargs)

    return _make_table

def _wrap_with_default_query_class(fn, cls):
    @functools.wraps(fn)
    def newfn(*args, **kwargs):
        _set_default_query_class(kwargs, cls)
        if "backref" in kwargs:
            backref = kwargs['backref']
            if isinstance(backref, string_types):
                backref = (backref, {})
            _set_default_query_class(backref[1], cls)
        return fn(*args, **kwargs)
    return newfn

def get_state(app):
    """Return the state stored under ``app.extensions['sqlalchemy']``.

    :raises AssertionError: if no ``'sqlalchemy'`` entry is registered on
        the application.
    """
    registered = app.extensions
    assert 'sqlalchemy' in registered, (
        'The sqlalchemy extension was not registered to the current '
        'application.  Please make sure to call init_app() first.'
    )
    return registered['sqlalchemy']

def _record_queries(app):
    if app.debug:
        return True
    rq = app.config['SQLALCHEMY_RECORD_QUERIES']
    if rq is not None:
        return rq
    return bool(app.config.get('TESTING'))




class _QueryProperty(object):
    def __init__(self, sa):
        self.sa = sa

    def __get__(self, obj, type):
        try:
            mapper = orm.class_mapper(type)
            if mapper:
                return type.query_class(mapper, session=self.sa.session())
        except UnmappedClassError:
            return None

class _EngineConnector(object):

    def __init__(self, sa, app, bind=None):
        self._sa = sa
        self._app = app
        self._engine = None
        self._connected_for = None
        self._bind = bind
        self._lock = Lock()

    def get_uri(self):
        if self._bind is None:
            return self._app.config['SQLALCHEMY_DATABASE_URI']
        binds = self._app.config.get('SQLALCHEMY_BINDS') or ()
        assert self._bind in binds, \
            'Bind %r is not specified.  Set it in the SQLALCHEMY_BINDS ' \
            'configuration variable' % self._bind
        return binds[self._bind]

    def get_engine(self):
        with self._lock:
            uri = self.get_uri()
            echo = self._app.config['SQLALCHEMY_ECHO']
            if (uri, echo) == self._connected_for:
                return self._engine

            sa_url = make_url(uri)
            options = self.get_options(sa_url, echo)
            self._engine = rv = self._sa.create_engine(sa_url, options)

            if _record_queries(self._app):
                _EngineDebuggingSignalEvents(self._engine,
                                             self._app.import_name).register()

            self._connected_for = (uri, echo)

            return rv

    def get_options(self, sa_url, echo):
        options = {}

        self._sa.apply_pool_defaults(self._app, options)
        self._sa.apply_driver_hacks(self._app, sa_url, options)
        if echo:
            options['echo'] = echo

        # Give the config options set by a developer explicitly priority
        # over decisions FSA makes.
        options.update(self._app.config['SQLALCHEMY_ENGINE_OPTIONS'])

        # Give options set in SQLAlchemy.__init__() ultimate priority
        options.update(self._sa._engine_options)

        return options

class SignallingSession(SessionBase):
    """Session that can record model changes and honours bind keys.

    :param db: extension object; its ``get_app``, ``engine`` and
        ``get_binds`` supply the application, default bind and table binds.
    :param autocommit: forwarded to the base session.
    :param autoflush: forwarded to the base session.
    :param options: further keyword arguments for the base session;
        ``bind`` and ``binds`` override the values derived from ``db``.
    """

    def __init__(self, db, autocommit=False, autoflush=True, **options):
        app = db.get_app()
        #: The application that this session belongs to.
        self.app = app

        track_modifications = app.config['SQLALCHEMY_TRACK_MODIFICATIONS']
        bind = options.pop('bind', None) or db.engine
        binds = options.pop('binds', db.get_binds(app))

        # An unset (None) flag counts as enabled, just like a truthy one.
        tracking_enabled = track_modifications is None or track_modifications
        if tracking_enabled:
            _SessionSignalEvents.register(self)

        SessionBase.__init__(
            self,
            autocommit=autocommit,
            autoflush=autoflush,
            bind=bind,
            binds=binds,
            **options
        )

    def get_bind(self, mapper=None, clause=None):
        """Return the engine or connection for a given model or
        table, using the ``__bind_key__`` if it is set.
        """
        # Without a mapper the caller just wants a connection.
        if mapper is not None:
            try:
                # SA >= 1.3
                selectable = mapper.persist_selectable
            except AttributeError:
                # SA < 1.3
                selectable = mapper.mapped_table

            bind_key = getattr(selectable, 'info', {}).get('bind_key')
            if bind_key is not None:
                extension = get_state(self.app).db
                return extension.get_engine(self.app, bind=bind_key)

        return SessionBase.get_bind(self, mapper, clause)

class _SessionSignalEvents(object):
    @classmethod
    def register(cls, session):
        if not hasattr(session, '_model_changes'):
            session._model_changes = {}

        event.listen(session, 'before_flush', cls.record_ops)
        event.listen(session, 'before_commit', cls.record_ops)
        event.listen(session, 'before_commit', cls.before_commit)
        event.listen(session, 'after_commit', cls.after_commit)
        event.listen(session, 'after_rollback', cls.after_rollback)

    @classmethod
    def unregister(cls, session):
        if hasattr(session, '_model_changes'):
            del session._model_changes

        event.remove(session, 'before_flush', cls.record_ops)
        event.remove(session, 'before_commit', cls.record_ops)
        event.remove(session, 'before_commit', cls.before_commit)
        event.remove(session, 'after_commit', cls.after_commit)
        event.remove(session, 'after_rollback', cls.after_rollback)

    @staticmethod
    def record_ops(session, flush_context=None, instances=None):
        try:
            d = session._model_changes
        except AttributeError:
            return

        for targets, operation in ((session.new, 'insert'), (session.dirty, 'update'), (session.deleted, 'delete')):
            for target in targets:
                state = inspect(target)
                key = state.identity_key if state.has_identity else id(target)
                d[key] = (target, operation)

    @staticmethod
    def before_commit(session):
        try:
            d = session._model_changes
        except AttributeError:
            return

        if d:
            before_models_committed.send(session.app, changes=list(d.values()))

    @staticmethod
    def after_commit(session):
        try:
            d = session._model_changes
        except AttributeError:
            return

        if d:
            models_committed.send(session.app, changes=list(d.values()))
            d.clear()

    @staticmethod
    def after_rollback(session):
        try:
            d = session._model_changes
        except AttributeError:
            return

        d.clear()

class FSADeprecationWarning(DeprecationWarning):
    """Deprecation warning category specific to this extension."""











class SQLAlchemy(object):
    """Ties SQLAlchemy sessions, models and engines to an application.

    An instance exposes the names from ``sqlalchemy`` and ``sqlalchemy.orm``
    as attributes (via ``_include_sqlalchemy``), a scoped ``session`` and a
    declarative ``Model`` base class.  Engines are created lazily, one per
    application and bind key, through :meth:`get_engine`.

    :param app: optional application; when given, :meth:`init_app` is called
        immediately.
    :param use_native_unicode: fallback used when the
        ``SQLALCHEMY_NATIVE_UNICODE`` config value is ``None`` (deprecated,
        see :meth:`apply_driver_hacks`).
    :param session_options: dict of keyword arguments for the session
        factory; an optional ``'scopefunc'`` entry is handed to
        ``scoped_session`` instead.
    :param metadata: optional metadata object for the declarative base.
    :param query_class: query class used for the session, ``Model.query``
        and the wrapped relationship helpers.
    :param model_class: base class (or ready-made declarative base) for
        ``Model``.
    :param engine_options: dict of engine options; these override every
        option derived from the application config.
    """

    def __init__(self, app=None, use_native_unicode=True, session_options=None,
                 metadata=None, query_class=orm.Query, model_class=Model,
                 engine_options=None):

        self.use_native_unicode = use_native_unicode
        self.Query = query_class
        self.session = self.create_scoped_session(session_options)
        self.Model = self.make_declarative_base(model_class, metadata)
        self._engine_lock = Lock()
        self.app = app
        self._engine_options = engine_options or {}
        _include_sqlalchemy(self, query_class)

        if app is not None:
            self.init_app(app)

    def create_scoped_session(self, options=None):
        """Create the scoped session stored on ``self.session``.

        :param options: optional dict of session factory options.  A
            ``'scopefunc'`` entry, if present, is passed to
            ``scoped_session``; ``'query_cls'`` defaults to ``self.Query``.
            The dict is copied, so the caller's object is left untouched.
        """
        options = dict(options or {})

        # ``scopefunc`` is optional: a missing key must not raise KeyError.
        scopefunc = options.pop('scopefunc', None)
        options.setdefault('query_cls', self.Query)
        return orm.scoped_session(
            self.create_session(options), scopefunc=scopefunc
        )

    def create_session(self, options):
        """Return a ``sessionmaker`` producing :class:`SignallingSession`
        objects bound to this extension.

        :param options: dict of keyword arguments for ``sessionmaker``.
        """
        return orm.sessionmaker(class_=SignallingSession, db=self, **options)

    def make_declarative_base(self, model, metadata=None):
        """Create the declarative base class used as ``self.Model``.

        :param model: either an existing declarative base (an instance of
            ``DeclarativeMeta``), used as is, or a plain class passed to
            ``declarative_base`` as ``cls``.
        :param metadata: optional metadata object to attach to the base.
        :returns: the base class, with ``query_class`` defaulted to
            ``self.Query`` and ``query`` set to a :class:`_QueryProperty`.
        """
        if not isinstance(model, DeclarativeMeta):
            model = declarative_base(
                cls=model,
                name='Model',
                metadata=metadata,
                metaclass=DefaultMeta
            )

        # An explicitly supplied metadata object replaces whatever an
        # existing declarative base already carried.
        if metadata is not None and model.metadata is not metadata:
            model.metadata = metadata

        if not getattr(model, 'query_class', None):
            model.query_class = self.Query

        model.query = _QueryProperty(self)
        return model

    def init_app(self, app):
        """Prepare ``app`` for use with this extension.

        Fills in defaults for every ``SQLALCHEMY_*`` config key, emits the
        relevant deprecation warnings and stores a ``_SQLAlchemyState`` for
        this instance in ``app.extensions['sqlalchemy']``.
        """
        if (
            'SQLALCHEMY_DATABASE_URI' not in app.config and
            'SQLALCHEMY_BINDS' not in app.config
        ):
            # The message below says (in Chinese): neither
            # SQLALCHEMY_DATABASE_URI nor SQLALCHEMY_BINDS was found in
            # app.config; SQLALCHEMY_DATABASE_URI defaults to
            # "sqlite:///:memory:".
            warnings.warn(
                '在app.config中既没找到 SQLALCHEMY_DATABASE_URI 也没找到 SQLALCHEMY_BINDS'
                '默认设置 SQLALCHEMY_DATABASE_URI 成 "sqlite:///:memory:"'
            )

        app.config.setdefault('SQLALCHEMY_DATABASE_URI', 'sqlite:///:memory:')
        app.config.setdefault('SQLALCHEMY_BINDS', None)
        app.config.setdefault('SQLALCHEMY_NATIVE_UNICODE', None)
        app.config.setdefault('SQLALCHEMY_ECHO', False)
        app.config.setdefault('SQLALCHEMY_RECORD_QUERIES', None)
        app.config.setdefault('SQLALCHEMY_POOL_SIZE', None)
        app.config.setdefault('SQLALCHEMY_POOL_TIMEOUT', None)
        app.config.setdefault('SQLALCHEMY_POOL_RECYCLE', None)
        app.config.setdefault('SQLALCHEMY_MAX_OVERFLOW', None)
        app.config.setdefault('SQLALCHEMY_COMMIT_ON_TEARDOWN', False)
        app.config.setdefault('SQLALCHEMY_ENGINE_OPTIONS', {})
        track_modifications = app.config.setdefault(
            'SQLALCHEMY_TRACK_MODIFICATIONS', None
        )

        if track_modifications is None:
            warnings.warn(FSADeprecationWarning(
                'SQLALCHEMY_TRACK_MODIFICATIONS adds significant overhead and '
                'will be disabled by default in the future.  Set it to True '
                'or False to suppress this warning.'
            ))

        engine_config_warning(app.config, '3.0', 'SQLALCHEMY_POOL_SIZE', 'pool_size')
        engine_config_warning(app.config, '3.0', 'SQLALCHEMY_POOL_TIMEOUT', 'pool_timeout')
        engine_config_warning(app.config, '3.0', 'SQLALCHEMY_POOL_RECYCLE', 'pool_recycle')
        engine_config_warning(app.config, '3.0', 'SQLALCHEMY_MAX_OVERFLOW', 'max_overflow')

        app.extensions['sqlalchemy'] = _SQLAlchemyState(self)

    def create_all(self, app=None, bind='__all__'):
        """Create the tables of the selected binds.

        :param app: application to use; defaults to the stored one (see
            :meth:`get_app`).
        :param bind: ``'__all__'`` for every bind, a single bind key (or
            ``None`` for the default bind), or an iterable of bind keys.
        """
        self._execute_for_all_tables(app, bind, 'create_all')

    def drop_all(self, bind='__all__', app=None):
        """Drop the tables of the selected binds.

        Takes the same ``bind`` and ``app`` values as :meth:`create_all`.
        """
        self._execute_for_all_tables(app, bind, 'drop_all')

    def _execute_for_all_tables(self, app, bind, operation, skip_tables=False):
        """Call the metadata method ``operation`` once per selected bind.

        :param app: application passed to :meth:`get_app`.
        :param bind: ``'__all__'``, one bind key (a string or ``None``), or
            an iterable of bind keys.
        :param operation: name of the method looked up on
            ``self.Model.metadata``.
        :param skip_tables: when true, no ``tables`` argument is passed to
            the operation.
        """
        app = self.get_app(app)

        if bind == '__all__':
            binds = [None] + list(app.config.get('SQLALCHEMY_BINDS') or ())
        elif isinstance(bind, str) or bind is None:
            binds = [bind]
        else:
            binds = bind

        for bind_key in binds:
            extra = {}
            if not skip_tables:
                extra['tables'] = self.get_tables_for_bind(bind_key)
            op = getattr(self.Model.metadata, operation)
            op(bind=self.get_engine(app, bind_key), **extra)

    def get_tables_for_bind(self, bind=None):
        """Return the list of tables whose ``info['bind_key']`` equals
        ``bind``.
        """
        return [
            table
            for table in self.Model.metadata.tables.values()
            if table.info.get('bind_key') == bind
        ]

    def get_app(self, reference_app=None):
        """Return the application to work with.

        A given ``reference_app`` is returned and also remembered as
        ``self.app``; otherwise the stored ``self.app`` is used.

        :raises RuntimeError: if neither is available.
        """
        if reference_app is not None:
            self.app = reference_app
            return reference_app

        if self.app is not None:
            return self.app

        raise RuntimeError(
            'No application found. Either work inside a view function or push'
            ' an application context. See'
            ' http://flask-sqlalchemy.pocoo.org/contexts/.'
        )

    @property
    def engine(self):
        """The engine of the default bind for the stored application.

        Shorthand for ``get_engine()``; raises :exc:`RuntimeError` when no
        application is known (see :meth:`get_app`).
        """
        return self.get_engine()

    def make_connector(self, app=None, bind=None):
        """Creates the connector for a given state and bind."""
        return _EngineConnector(self, self.get_app(app), bind)

    def get_engine(self, app=None, bind=None):
        """Returns a specific engine.

        :param app: application to use; defaults to the stored one.
        :param bind: bind key, or ``None`` for the default bind.

        The connector for each bind is created on first use and kept in
        the extension state of the application.
        """
        app = self.get_app(app)
        state = get_state(app)

        with self._engine_lock:
            connector = state.connectors.get(bind)

            if connector is None:
                connector = self.make_connector(app, bind)
                state.connectors[bind] = connector

            return connector.get_engine()

    def apply_pool_defaults(self, app, options):
        """Copy the non-``None`` ``SQLALCHEMY_POOL_*`` / ``MAX_OVERFLOW``
        config values of ``app`` into the ``options`` dict.
        """
        pool_settings = (
            ('pool_size', 'SQLALCHEMY_POOL_SIZE'),
            ('pool_timeout', 'SQLALCHEMY_POOL_TIMEOUT'),
            ('pool_recycle', 'SQLALCHEMY_POOL_RECYCLE'),
            ('max_overflow', 'SQLALCHEMY_MAX_OVERFLOW'),
        )
        for option_key, config_key in pool_settings:
            value = app.config[config_key]
            if value is not None:
                options[option_key] = value

    def apply_driver_hacks(self, app, sa_url, options):
        """Adjust ``sa_url`` and ``options`` for specific database drivers.

        MySQL URLs get a default ``charset`` and pool settings; SQLite gets
        a pool class suited to in-memory or file databases, and file paths
        are joined onto ``app.root_path``.  Also applies the deprecated
        native-unicode settings.

        :raises RuntimeError: for an in-memory SQLite database configured
            with ``pool_size`` 0.
        """
        # NOTE(review): sa_url is modified in place below, which needs a
        # mutable URL object - confirm against the SQLAlchemy version used.
        if sa_url.drivername.startswith('mysql'):
            sa_url.query.setdefault('charset', 'utf8')
            if sa_url.drivername != 'mysql+gaerdbms':
                options.setdefault('pool_size', 10)
                options.setdefault('pool_recycle', 7200)
        elif sa_url.drivername == 'sqlite':
            pool_size = options.get('pool_size')
            detected_in_memory = False
            if sa_url.database in (None, '', ':memory:'):
                detected_in_memory = True
                from sqlalchemy.pool import StaticPool
                options['poolclass'] = StaticPool
                if 'connect_args' not in options:
                    options['connect_args'] = {}
                options['connect_args']['check_same_thread'] = False

                if pool_size == 0:
                    raise RuntimeError('SQLite in memory database with an '
                                       'empty queue not possible due to data '
                                       'loss.')
            elif not pool_size:
                from sqlalchemy.pool import NullPool
                options['poolclass'] = NullPool

            if not detected_in_memory:
                sa_url.database = os.path.join(app.root_path, sa_url.database)

        # The config value wins; the constructor argument is the fallback.
        unu = app.config['SQLALCHEMY_NATIVE_UNICODE']
        if unu is None:
            unu = self.use_native_unicode
        if not unu:
            options['use_native_unicode'] = False

        if app.config['SQLALCHEMY_NATIVE_UNICODE'] is not None:
            warnings.warn(
                "The 'SQLALCHEMY_NATIVE_UNICODE' config option is deprecated and will be removed in"
                " v3.0.  Use 'SQLALCHEMY_ENGINE_OPTIONS' instead.",
                DeprecationWarning
            )
        if not self.use_native_unicode:
            warnings.warn(
                "'use_native_unicode' is deprecated and will be removed in v3.0."
                "  Use the 'engine_options' parameter instead.",
                DeprecationWarning
            )

    def create_engine(self, sa_url, engine_opts):
        """
            Override this method to have final say over how the SQLAlchemy engine
            is created.

            In most cases, you will want to use ``'SQLALCHEMY_ENGINE_OPTIONS'``
            config variable or set ``engine_options`` for :func:`SQLAlchemy`.
        """
        return sqlalchemy.create_engine(sa_url, **engine_opts)

    def get_binds(self, app=None):
        """Returns a dictionary with a table->engine mapping.

        This is suitable for use of sessionmaker(binds=db.get_binds(app)).
        """
        app = self.get_app(app)
        binds = [None] + list(app.config.get('SQLALCHEMY_BINDS') or ())
        retval = {}
        for bind in binds:
            engine = self.get_engine(app, bind)
            tables = self.get_tables_for_bind(bind)
            retval.update(dict.fromkeys(tables, engine))
        return retval

    @property
    def metadata(self):
        """The metadata object of ``self.Model``."""
        return self.Model.metadata

    def reflect(self, bind='__all__', app=None):
        """Call ``metadata.reflect`` for the selected binds.

        Takes the same ``bind`` and ``app`` values as :meth:`create_all`;
        no ``tables`` argument is passed to the operation.
        """
        self._execute_for_all_tables(app, bind, 'reflect', skip_tables=True)

    def __repr__(self):
        # Without a stored application there is no engine to describe;
        # show None instead of failing.
        return '<%s engine=%r>' % (
            self.__class__.__name__,
            self.engine.url if self.app is not None else None
        )





